# Author: WY
# Date: 2023/07/12 15:16
# Describe: 连续的超参数性能测试

import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pylab as plt
from sklearn import metrics
from pylab import mpl
import os
import pickle

# Work around the "OMP: Error #15" duplicate OpenMP runtime crash seen when
# mixing torch and matplotlib on some platforms
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Use a font that can render Chinese characters in figure titles/labels
mpl.rcParams['font.sans-serif'] = ['SimHei']
# Keep the minus sign rendering correctly when a non-default font is active
mpl.rcParams["axes.unicode_minus"] = False
# Run on GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters (module-level; read directly by Model and the train/test
# routines below)
scopes = 100 # number of training epochs (an earlier run used 40)
lr = 0.0002  # Adam learning rate
split = 0.8  # presumably the train/test split ratio applied by the caller — not used in this file; TODO confirm
dropout = 0.2 # Dropout2d probability after the conv/pool stack (earlier: 0.2)
batch_size = 40  # length of the input window (rows fed to the conv as "height")
in_channels = 1  # single input channel for Conv2d
out_channels = 52 # conv output channels (earlier run used 52)
kernel_size = 5  # conv and pooling kernel size
stride = 1       # conv and pooling stride (spatial size preserved)
feature_dim = 13 # number of features per time step ("width" of the conv input)
output_size = 1  # scalar prediction per time step
hidden_size = 39 # GRU hidden size (earlier: 39)
num_layers = 3 # number of stacked GRU layers (earlier: 3)

# Model artifact paths — superseded by the per-step paths built dynamically in
# performance_record(); kept for reference
# pkl_path = 'result/temperature/results.pkl'
# losses_img_path = 'result/temperature/losses.png'
# bias_img_path = 'result/temperature/bias.png'
# variance_img_path = 'result/temperature/variance.png'
# val_path = 'result/temperature/val.txt'

# Forecasting model: a Conv2d feature extractor over the input window,
# followed by a stacked GRU and two linear projections.
class Model(nn.Module):
    """CNN + GRU forecaster.

    Consumes one window of `batch_size` time steps with `feature_dim`
    features each and emits `step` future scalar values. All layer sizes
    come from the module-level hyperparameters.
    """

    def __init__(self, step):
        super(Model, self).__init__()
        # Prediction horizon: how many future points this instance outputs
        self.pre_step = step
        # Convolution over the [batch_size x feature_dim] "image";
        # padding='same' keeps the spatial size unchanged
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            kernel_size=kernel_size,
            padding='same',
        )
        self.relu = nn.ReLU()
        # Pooling with symmetric padding so the spatial size is preserved
        self.pool = nn.MaxPool2d(
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
        )
        self.dropout = nn.Dropout2d(dropout)
        # GRU consumes the flattened channels*features vector per time step
        self.gru = nn.GRU(
            input_size=feature_dim * out_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Small-Gaussian init for every GRU weight and bias tensor
        for param in self.gru.parameters():
            nn.init.normal_(param, mean=0.0, std=0.001)
        # Per-step projection to a scalar, then map the time axis
        # (batch_size) down to the prediction horizon
        self.linear = nn.Linear(hidden_size, output_size)
        self.fc = nn.Linear(batch_size, self.pre_step)

    def forward(self, input):
        # Shape for the conv stack: [batch=1, channel, height, width]
        x = input.reshape([1, in_channels, batch_size, feature_dim])
        # Conv -> ReLU -> MaxPool -> Dropout
        x = self.dropout(self.pool(self.relu(self.conv(x))))
        # Flatten channels into the feature axis for the GRU:
        # [batch=1, seq=batch_size, feature_dim * out_channels]
        x = x.reshape(1, batch_size, feature_dim * out_channels)
        hidden_seq, _ = self.gru(x)
        # Drop the leading batch dim: [batch_size, hidden_size]
        hidden_seq = torch.squeeze(hidden_seq, 0)
        # [batch_size, 1] -> transpose -> [1, batch_size] -> [1, pre_step]
        projected = self.fc(self.linear(hidden_seq).T)
        # Final shape: [pre_step, 1]
        return projected.T


# 训练
def train(dataloader_train, step, pkl_path, losses_img_path, val_path):
    """Train a Model for prediction horizon `step` and persist the results.

    Args:
        dataloader_train: iterable yielding tensors of shape
            [batch_size + step, feature_dim] (input window + targets).
        step: prediction horizon (number of future points).
        pkl_path: checkpoint file to resume from / save to.
        losses_img_path: output path for the loss-curve image.
        val_path: output path for the per-epoch loss values (text).
    """
    # Sampled batch losses within one epoch, and per-epoch sums
    losses = []
    loss_record = []
    # Instantiate model, loss function, and optimizer
    model = Model(step).to(device)
    loss_fun = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr)
    # Resume from an earlier checkpoint if one exists.
    # map_location=device: a checkpoint saved on CUDA would otherwise fail
    # to load on a CPU-only machine.
    if os.path.exists(pkl_path):
        model.load_state_dict(torch.load(pkl_path, map_location=device))
    # Explicit train mode (enables Dropout)
    model.train()
    start_date = datetime.datetime.now()
    print("start")
    for scope in range(scopes):
        for batch_idx, data in enumerate(dataloader_train):
            # First batch_size rows form the input window:
            # [channel=1, batch_size, feature_dim]
            input = data[0:batch_size].float().view(1, batch_size, feature_dim).to(device)
            # Remaining rows, first feature only, are the target:
            # [step, output_size]
            label = data[batch_size:, 0].float().view(step, output_size).to(device)

            output = model(input)
            loss = loss_fun(output, label)
            # Clear gradients via the optimizer (idiomatic; equivalent here
            # to model.zero_grad() since it owns all model parameters)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sample the loss every 10 batches to keep the record small
            if batch_idx % 10 == 0:
                losses.append(loss.item())

        # Renamed from `sum`, which shadowed the builtin
        epoch_loss = np.sum(losses)
        loss_record.append(epoch_loss)
        print(f"scope: {scope}, losses: {epoch_loss}")
        losses.clear()

    end_date = datetime.datetime.now()
    # Create the parent directory if missing, then save the checkpoint
    os.makedirs(os.path.dirname(pkl_path), exist_ok=True)
    torch.save(model.state_dict(), pkl_path)
    # Plot the loss curve; the title includes wall-clock training time
    fig = plt.figure()
    plt.plot(np.array(range(len(loss_record))), loss_record)
    plt.title(label=f'损失值变化 训练耗时:{end_date - start_date}')
    plt.show()
    fig.savefig(losses_img_path)
    # Persist the raw per-epoch loss values
    np.savetxt(val_path, loss_record, fmt='%.7f', delimiter=',')


# 用训练集测模型偏差
def bias_test(dataloader_train, step, pkl_path, bias_img_path):
    """Evaluate model bias on the training set and plot predictions vs. labels.

    Args:
        dataloader_train: iterable yielding tensors of shape
            [batch_size + step, feature_dim].
        step: prediction horizon the checkpoint was trained for.
        pkl_path: checkpoint file to load.
        bias_img_path: output path for the comparison plot image.
    """
    outputs = []
    labels = []
    model = Model(step).to(device)
    # map_location=device: a CUDA-saved checkpoint would otherwise fail to
    # load on a CPU-only machine
    if os.path.exists(pkl_path):
        model.load_state_dict(torch.load(pkl_path, map_location=device))
    # Evaluation mode: disables Dropout so predictions are deterministic
    # (previously missing — dropout stayed active during this test)
    model.eval()

    # No gradients needed for evaluation
    with torch.no_grad():
        for batch_idx, data in enumerate(dataloader_train):
            input = data[0:batch_size].float().view(1, batch_size, feature_dim).to(device)
            label = data[batch_size:, 0].float().view(step, output_size).tolist()

            output = model(input)

            # Collect predictions as plain Python lists
            outputs.extend(output.cpu().tolist())
            labels.extend(label)

    # Plot predictions against ground truth
    xAxis = np.array(range(len(outputs)))
    yAxis_predict = np.array(outputs)
    yAxis_label = np.array(labels)
    fig = plt.figure()
    plt.plot(xAxis, yAxis_label, 'r-.', label='label')
    plt.plot(xAxis, yAxis_predict, 'b-.', label='predict')
    plt.title(label='模型偏差')
    # Labels were passed to plot() but never shown without this call
    plt.legend(loc='upper right')
    plt.show()
    fig.savefig(bias_img_path)

# 用测试集测模型方差
def variance_test(dataloader_test, step, pkl_path, variance_img_path):
    """Evaluate model variance on the test set: RMSE/MAE plus a comparison plot.

    Args:
        dataloader_test: iterable yielding tensors of shape
            [batch_size + step, feature_dim].
        step: prediction horizon the checkpoint was trained for.
        pkl_path: checkpoint file to load.
        variance_img_path: output path for the comparison plot image.
    """
    outputs = []
    labels = []
    model = Model(step).to(device)
    # map_location=device: a CUDA-saved checkpoint would otherwise fail to
    # load on a CPU-only machine
    if os.path.exists(pkl_path):
        model.load_state_dict(torch.load(pkl_path, map_location=device))
    # Evaluation mode: disables Dropout so predictions are deterministic
    # (previously missing — dropout stayed active during this test)
    model.eval()

    # No gradients needed for evaluation
    with torch.no_grad():
        for batch_idx, data in enumerate(dataloader_test):
            input = data[0:batch_size].float().view(1, batch_size, feature_dim).to(device)
            # .float() added for consistency with bias_test/train
            label = data[batch_size:, 0].float().view(step, output_size).tolist()

            output = model(input)

            outputs.extend(output.cpu().tolist())
            labels.extend(label)

    xAxis = np.array(range(len(outputs)))
    yAxis_predict = np.array(outputs)
    yAxis_label = np.array(labels)
    # Error metrics over the full test set
    RMSE = np.sqrt(metrics.mean_squared_error(yAxis_label, yAxis_predict))
    MAE = metrics.mean_absolute_error(yAxis_label, yAxis_predict)
    # Plot predictions against ground truth, metrics in the title
    fig = plt.figure()
    plt.plot(xAxis, yAxis_label, 'r-.', label='真实值')
    plt.plot(xAxis, yAxis_predict, 'b-.', label='预测值')
    plt.title(label=f"模型方差 RMSE={RMSE} MAE={MAE}")
    plt.legend(loc='upper right')
    plt.show()
    fig.savefig(variance_img_path)


def performance_record(step, dataloader_train, dataloader_test):
    """Run the full train / bias / variance pipeline for one prediction step."""
    print(f"正在处理预测步长为{step}的模型流程")
    # All artifacts for this step live under one step-specific directory
    base = f'result/temperature/pre_step = {step}'
    pkl_path = f'{base}/results.pkl'
    losses_img_path = f'{base}/losses.png'
    bias_img_path = f'{base}/bias.png'
    variance_img_path = f'{base}/variance.png'
    val_path = f'{base}/val.txt'
    # Train, then evaluate bias (train set) and variance (test set)
    train(dataloader_train, step, pkl_path, losses_img_path, val_path)
    bias_test(dataloader_train, step, pkl_path, bias_img_path)
    variance_test(dataloader_test, step, pkl_path, variance_img_path)

